{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://storage.googleapis.com/bert_models/2018_10_18/cased_L-12_H-768_A-12.zip\n",
    "# !unzip cased_L-12_H-768_A-12.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '2'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/bert/optimization.py:87: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import bert\n",
    "from bert import run_classifier\n",
    "from bert import optimization\n",
    "from bert import tokenization\n",
    "from bert import modeling\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import pandas as pd\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/bert/tokenization.py:125: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "BERT_VOCAB = 'cased_L-12_H-768_A-12/vocab.txt'\n",
    "BERT_INIT_CHKPNT = 'cased_L-12_H-768_A-12/bert_model.ckpt'\n",
    "BERT_CONFIG = 'cased_L-12_H-768_A-12/bert_config.json'\n",
    "\n",
    "tokenization.validate_case_matches_checkpoint(True, '')\n",
    "tokenizer = tokenization.FullTokenizer(\n",
    "      vocab_file=BERT_VOCAB, do_lower_case=False)\n",
    "MAX_SEQ_LENGTH = 120"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['train_X', 'train_Y', 'test_X', 'test_Y'])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import json\n",
    "\n",
    "with open('text.json') as fopen:\n",
    "    data = json.load(fopen)\n",
    "    \n",
    "data.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _truncate_seq_pair(tokens_a, tokens_b, max_length):\n",
    "    while True:\n",
    "        total_length = len(tokens_a) + len(tokens_b)\n",
    "        if total_length <= max_length:\n",
    "            break\n",
    "        if len(tokens_a) > len(tokens_b):\n",
    "              tokens_a.pop()\n",
    "        else:\n",
    "              tokens_b.pop()\n",
    "\n",
    "def get_data(left, right):\n",
    "    input_ids, input_masks, segment_ids = [], [], []\n",
    "    for i in tqdm(range(len(left))):\n",
    "        tokens_a = tokenizer.tokenize(left[i])\n",
    "        tokens_b = tokenizer.tokenize(right[i])\n",
    "        _truncate_seq_pair(tokens_a, tokens_b, MAX_SEQ_LENGTH - 3)\n",
    "        tokens = []\n",
    "        segment_id = []\n",
    "        tokens.append(\"[CLS]\")\n",
    "        segment_id.append(0)\n",
    "        for token in tokens_a:\n",
    "            tokens.append(token)\n",
    "            segment_id.append(0)\n",
    "        tokens.append(\"[SEP]\")\n",
    "        segment_id.append(0)\n",
    "        for token in tokens_b:\n",
    "            tokens.append(token)\n",
    "            segment_id.append(1)\n",
    "        tokens.append(\"[SEP]\")\n",
    "        segment_id.append(1)\n",
    "        input_id = tokenizer.convert_tokens_to_ids(tokens)\n",
    "        input_mask = [1] * len(input_id)\n",
    "\n",
    "        while len(input_id) < MAX_SEQ_LENGTH:\n",
    "            input_id.append(0)\n",
    "            input_mask.append(0)\n",
    "            segment_id.append(0)\n",
    "\n",
    "        input_ids.append(input_id)\n",
    "        input_masks.append(input_mask)\n",
    "        segment_ids.append(segment_id)\n",
    "    return input_ids, input_masks, segment_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 261802/261802 [02:41<00:00, 1623.32it/s]\n"
     ]
    }
   ],
   "source": [
    "left, right = [], []\n",
    "for i in range(len(data['train_X'])):\n",
    "    l, r = data['train_X'][i].split(' <> ')\n",
    "    left.append(l)\n",
    "    right.append(r)\n",
    "    \n",
    "train_ids, train_masks, segment_train = get_data(left, right)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 13395/13395 [00:08<00:00, 1575.46it/s]\n"
     ]
    }
   ],
   "source": [
    "left, right = [], []\n",
    "for i in range(len(data['test_X'])):\n",
    "    l, r = data['test_X'][i].split(' <> ')\n",
    "    left.append(l)\n",
    "    right.append(r)\n",
    "    \n",
    "test_ids, test_masks, segment_test = get_data(left, right)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_config = modeling.BertConfig.from_json_file(BERT_CONFIG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "epoch = 10\n",
    "batch_size = 60\n",
    "warmup_proportion = 0.1\n",
    "num_train_steps = int(len(left) / batch_size * epoch)\n",
    "num_warmup_steps = int(num_train_steps * warmup_proportion)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model:\n",
    "    def __init__(\n",
    "        self,\n",
    "        dimension_output,\n",
    "        learning_rate = 2e-5,\n",
    "    ):\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.segment_ids = tf.placeholder(tf.int32, [None, None])\n",
    "        self.input_masks = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "        self.batch_size = tf.shape(self.X)[0]\n",
    "        \n",
    "        model = modeling.BertModel(\n",
    "            config=bert_config,\n",
    "            is_training=True,\n",
    "            input_ids=self.X,\n",
    "            input_mask=self.input_masks,\n",
    "            token_type_ids=self.segment_ids,\n",
    "            use_one_hot_embeddings=False)\n",
    "        \n",
    "        output_layer = model.get_pooled_output()\n",
    "        self.out = tf.layers.dense(output_layer, bert_config.hidden_size)\n",
    "        self.out = tf.nn.l2_normalize(self.out, 1)\n",
    "        self.logits = tf.layers.dense(self.out,dimension_output,use_bias=False,\n",
    "                                      kernel_constraint=tf.keras.constraints.unit_norm())\n",
    "        \n",
    "        self.gamma = 64\n",
    "        self.margin = 0.25\n",
    "        self.O_p = 1 + self.margin\n",
    "        self.O_n = -self.margin\n",
    "        self.Delta_p = 1 - self.margin\n",
    "        self.Delta_n = self.margin\n",
    "        \n",
    "        self.batch_idxs = tf.expand_dims(\n",
    "          tf.range(0, self.batch_size, dtype=tf.int32), 1)  # shape [batch,1]\n",
    "        idxs = tf.concat([self.batch_idxs, tf.cast(self.Y, tf.int32)], 1)\n",
    "        sp = tf.expand_dims(tf.gather_nd(self.logits, idxs), 1)\n",
    "        mask = tf.logical_not(\n",
    "            tf.scatter_nd(idxs, tf.ones(tf.shape(idxs)[0], tf.bool),\n",
    "                          tf.shape(self.logits)))\n",
    "\n",
    "        sn = tf.reshape(tf.boolean_mask(self.logits, mask), (self.batch_size, -1))\n",
    "\n",
    "        alpha_p = tf.nn.relu(self.O_p - tf.stop_gradient(sp))\n",
    "        alpha_n = tf.nn.relu(tf.stop_gradient(sn) - self.O_n)\n",
    "\n",
    "        r_sp_m = alpha_p * (sp - self.Delta_p)\n",
    "        r_sn_m = alpha_n * (sn - self.Delta_n)\n",
    "        _Z = tf.concat([r_sn_m, r_sp_m], 1)\n",
    "        _Z = _Z * self.gamma\n",
    "        # sum all similarity\n",
    "        logZ = tf.math.reduce_logsumexp(_Z, 1, keepdims=True)\n",
    "        # remove sn_p from all sum similarity\n",
    "        self.cost = -r_sp_m * self.gamma + logZ\n",
    "        self.cost = tf.reduce_mean(self.cost[:,0])\n",
    "        \n",
    "        self.optimizer = optimization.create_optimizer(self.cost, learning_rate, \n",
    "                                                       num_train_steps, num_warmup_steps, False)\n",
    "        correct_pred = tf.equal(\n",
    "            tf.argmax(self.logits, 1, output_type = tf.int32), self.Y[:,0]\n",
    "        )\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/client/session.py:1750: UserWarning: An interactive session is already active. This can cause out-of-memory errors in some cases. You must explicitly call `InteractiveSession.close()` to release resources held by the other session(s).\n",
      "  warnings.warn('An interactive session is already active. This can '\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/array_ops.py:1475: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/bert/optimization.py:27: The name tf.train.get_or_create_global_step is deprecated. Please use tf.compat.v1.train.get_or_create_global_step instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/bert/optimization.py:32: The name tf.train.polynomial_decay is deprecated. Please use tf.compat.v1.train.polynomial_decay instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/bert/optimization.py:70: The name tf.trainable_variables is deprecated. Please use tf.compat.v1.trainable_variables instead.\n",
      "\n",
      "INFO:tensorflow:Restoring parameters from cased_L-12_H-768_A-12/bert_model.ckpt\n"
     ]
    }
   ],
   "source": [
    "dimension_output = 2\n",
    "learning_rate = 2e-5\n",
    "\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model(\n",
    "    dimension_output,\n",
    "    learning_rate\n",
    ")\n",
    "\n",
    "sess.run(tf.global_variables_initializer())\n",
    "var_lists = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope = 'bert')\n",
    "saver = tf.train.Saver(var_list = var_lists)\n",
    "saver.restore(sess, BERT_INIT_CHKPNT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = ['contradiction', 'entailment']\n",
    "\n",
    "train_Y = data['train_Y']\n",
    "test_Y = data['test_Y']\n",
    "\n",
    "train_Y = [labels.index(i) for i in train_Y]\n",
    "test_Y = [labels.index(i) for i in test_Y]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop:  91%|█████████▏| 3992/4364 [34:19<03:10,  1.95it/s, accuracy=0.883, cost=6.12]IOPub message rate exceeded.\n",
      "The notebook server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--NotebookApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "NotebookApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "train minibatch loop: 100%|██████████| 4364/4364 [37:26<00:00,  1.94it/s, accuracy=0.955, cost=3.35]\n",
      "test minibatch loop: 100%|██████████| 224/224 [00:40<00:00,  5.52it/s, accuracy=1, cost=0.312]   "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "time taken: 2287.008446931839\n",
      "epoch: 1, training loss: 5.740891, training acc: 0.911401, valid loss: 6.376333, valid acc: 0.899777\n",
      "\n",
      "break epoch:2\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "import time\n",
    "\n",
    "EARLY_STOPPING, CURRENT_CHECKPOINT, CURRENT_ACC, EPOCH = 1, 0, 0, 0\n",
    "\n",
    "while True:\n",
    "    lasttime = time.time()\n",
    "    if CURRENT_CHECKPOINT == EARLY_STOPPING:\n",
    "        print('break epoch:%d\\n' % (EPOCH))\n",
    "        break\n",
    "\n",
    "    train_acc, train_loss, test_acc, test_loss = [], [], [], []\n",
    "    pbar = tqdm(\n",
    "        range(0, len(train_ids), batch_size), desc = 'train minibatch loop'\n",
    "    )\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(train_ids))\n",
    "        batch_x = train_ids[i: index]\n",
    "        batch_masks = train_masks[i: index]\n",
    "        batch_segment = segment_train[i: index]\n",
    "        batch_y = train_Y[i: index]\n",
    "        batch_y = np.expand_dims(batch_y,1)\n",
    "        acc, cost, _ = sess.run(\n",
    "            [model.accuracy, model.cost, model.optimizer],\n",
    "            feed_dict = {\n",
    "                model.Y: batch_y,\n",
    "                model.X: batch_x,\n",
    "                model.segment_ids: batch_segment,\n",
    "                model.input_masks: batch_masks\n",
    "            },\n",
    "        )\n",
    "        assert not np.isnan(cost)\n",
    "        train_loss.append(cost)\n",
    "        train_acc.append(acc)\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "    \n",
    "    pbar = tqdm(range(0, len(test_ids), batch_size), desc = 'test minibatch loop')\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(test_ids))\n",
    "        batch_x = test_ids[i: index]\n",
    "        batch_masks = test_masks[i: index]\n",
    "        batch_segment = segment_test[i: index]\n",
    "        batch_y = test_Y[i: index]\n",
    "        batch_y = np.expand_dims(batch_y,1)\n",
    "        acc, cost = sess.run(\n",
    "            [model.accuracy, model.cost],\n",
    "            feed_dict = {\n",
    "                model.Y: batch_y,\n",
    "                model.X: batch_x,\n",
    "                model.segment_ids: batch_segment,\n",
    "                model.input_masks: batch_masks\n",
    "            },\n",
    "        )\n",
    "        test_loss.append(cost)\n",
    "        test_acc.append(acc)\n",
    "        pbar.set_postfix(cost = cost, accuracy = acc)\n",
    "\n",
    "    train_loss = np.mean(train_loss)\n",
    "    train_acc = np.mean(train_acc)\n",
    "    test_loss = np.mean(test_loss)\n",
    "    test_acc = np.mean(test_acc)\n",
    "    \n",
    "    if test_acc > CURRENT_ACC:\n",
    "        print(\n",
    "            'epoch: %d, pass acc: %f, current acc: %f'\n",
    "            % (EPOCH, CURRENT_ACC, test_acc)\n",
    "        )\n",
    "        CURRENT_ACC = test_acc\n",
    "        CURRENT_CHECKPOINT = 0\n",
    "    else:\n",
    "        CURRENT_CHECKPOINT += 1\n",
    "        \n",
    "    print('time taken:', time.time() - lasttime)\n",
    "    print(\n",
    "        'epoch: %d, training loss: %f, training acc: %f, valid loss: %f, valid acc: %f\\n'\n",
    "        % (EPOCH, train_loss, train_acc, test_loss, test_acc)\n",
    "    )\n",
    "    EPOCH += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "test minibatch loop: 100%|██████████| 224/224 [00:40<00:00,  5.51it/s, accuracy=1, cost=0.3]     \n"
     ]
    }
   ],
   "source": [
    "test_acc, test_loss = [], []\n",
    "\n",
    "pbar = tqdm(range(0, len(test_ids), batch_size), desc = 'test minibatch loop')\n",
    "for i in pbar:\n",
    "    index = min(i + batch_size, len(test_ids))\n",
    "    batch_x = test_ids[i: index]\n",
    "    batch_masks = test_masks[i: index]\n",
    "    batch_segment = segment_test[i: index]\n",
    "    batch_y = test_Y[i: index]\n",
    "    batch_y = np.expand_dims(batch_y,1)\n",
    "    acc, cost = sess.run(\n",
    "        [model.accuracy, model.cost],\n",
    "        feed_dict = {\n",
    "            model.Y: batch_y,\n",
    "            model.X: batch_x,\n",
    "            model.segment_ids: batch_segment,\n",
    "            model.input_masks: batch_masks\n",
    "        },\n",
    "    )\n",
    "    test_loss.append(cost)\n",
    "    test_acc.append(acc)\n",
    "    pbar.set_postfix(cost = cost, accuracy = acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(6.366118, 0.8990327)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_loss = np.mean(test_loss)\n",
    "test_acc = np.mean(test_acc)\n",
    "test_loss, test_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
